# -*- coding: utf-8 -*-
"""
Created on Wed Aug 31 16:06:34 2022

@author: wulong
"""
import numpy as np
from casadi import *
from Slow_model import*

# In[1] Slow EMPC
def formulate_opt_s(initial_state, state_guess, state_bnds_lo, state_bnds_up,
                    fast_state_guess, fast_state_bnds_lo, fast_state_bnds_up, 
                    input_guess, input_bnds_lo, input_bnds_up, 
                    fast_input_guess, fast_input_bnds_lo, fast_input_bnds_up, 
                    input_bin_bnds,
                    distb_bnds,
                    output_guess,
                    yspe_guess, yspe_bnds_lo, yspe_bnds_up,
                    output_bnds_lo, output_bnds_up, long_state_setpnts,  
                    alpha, pred_horzn, pf, pmg, pse,
                    X_int_f,
                    i, del_Uc, del_Ucf, Uc_0, Ucf_0, Uz_0,
                    pcm, ppn, rxi, rxi_as, yeb):
    """Formulate the slow-layer EMPC problem as a CasADi NLP.

    The horizon is discretised with multiple shooting. For every step
    ``k = 1 .. pred_horzn`` new CasADi ``MX`` symbols are created. They are
    appended to ``w`` in this fixed order, which any code unpacking the
    solution vector must follow::

        U{k-1} (Nuc_s), U_f{k-1} (Nuc_f), X{k} (Nx_s), X_f{k} (Nx_f),
        y{k} (Ny_s), y_spe{k} (1)

    Constraints are appended to ``g`` per step in this order:

    1. Increment constraints ``U{k-1} - U_prev`` and ``U_f{k-1} - U_f_prev``.
       These are omitted only when ``i == 0`` and ``k == 1``.
    2. Continuity ``I_ode_s(...)['xf'] - X{k} == 0`` (Nx_s rows).
    3. ``const_ies_s(X{k}, X_f{k}, ...) == 0`` (Nx_f rows).
    4. Output definition ``y{k} - out_ies_s(...) == 0`` (Ny_s rows).

    ``Delta_s``, ``Nuc_s``, ``Nuc_f``, ``Nx_s``, ``Nx_f``, ``Ny_s``,
    ``I_ode_s``, ``const_ies_s`` and ``out_ies_s`` are not defined in this
    function. Presumably they come from ``from Slow_model import *``.

    Parameters
    ----------
    initial_state :
        Numeric slow state the first step is integrated from.
    state_guess, fast_state_guess, input_guess, fast_input_guess,
    output_guess, yspe_guess :
        Initial guesses, indexed as ``[k-1, :]`` for step ``k`` (one row per
        step).
    state_bnds_lo, state_bnds_up, fast_state_bnds_lo, fast_state_bnds_up,
    output_bnds_lo, output_bnds_up :
        Bounds applied unchanged at every step.
    input_bnds_lo, input_bnds_up, fast_input_bnds_lo, fast_input_bnds_up :
        Input bounds. They are multiplied element-wise by the binary masks
        ``Zu_s`` / ``Zu_f``, so an input whose flag is 0 gets bounds [0, 0].
    input_bin_bnds :
        Per-step binary on/off flags (row ``k-1``; columns 0-2 are used).
    distb_bnds :
        Per-step disturbance values (row ``k-1``); column 2 enters the cost
        via ``pmg``.
    yspe_bnds_lo, yspe_bnds_up :
        Per-step bounds (index ``k-1``) on the scalar zone variable
        ``y_spe{k}``.
    long_state_setpnts :
        Per-step targets for ``X{k}[0]`` (column 0) and ``X{k}[1]``
        (column 1).
    alpha :
        Five weights for the stage-cost terms (see the cost below).
    pred_horzn :
        Number of prediction steps.
    pf, pmg :
        Coefficients of the economic term. They are not indexed by ``k``,
        so presumably they are scalars constant over the horizon.
    pse, pcm, ppn, rxi, rxi_as, yeb :
        Per-step sequences indexed ``[k-1]``. The in-code comment names
        them sell / compensation / penalty price and regulation factor.
        The meaning of ``rxi_as`` and ``yeb`` is not visible here.
    X_int_f :
        Fast-state values used as parameters for the first integration
        step.
    i :
        Index checked together with ``k`` to skip the very first increment
        constraint. Presumably this is the closed-loop iteration counter;
        confirm against the caller.
    del_Uc, del_Ucf :
        Input rate limits for slow and fast inputs. They are multiplied by
        ``Delta_s`` to obtain the per-step bound.
    Uc_0, Ucf_0, Uz_0 :
        Slow inputs, fast inputs and binary flags of the previous time
        instant.

    Returns
    -------
    tuple
        ``(w, lbw, ubw, w0, g, lbg, ubg, J)``. All entries except ``J`` are
        concatenated with ``vertcat``. ``J`` is the accumulated objective
        expression.
    """
    w = []      # decision-variable symbols
    w0 = []     # initial guess for w
    lbw = []    # lower bounds on w
    ubw = []    # upper bounds on w
    J = 0       # objective, accumulated over the horizon
    g = []      # constraint expressions
    lbg = []    # lower bounds on g
    ubg = []    # upper bounds on g
    
    # "Lift" the initial condition: for k == 1 the model is integrated from
    # the numeric initial_state; later steps use the X{k-1} decision variable.
    Xk = initial_state
    
    # Fast states passed as parameters to the first integration step; later
    # steps use the X_f{k-1} decision variable.
    Xfk = X_int_f
    
    # Inputs and binary flags applied at the previous time instant; they
    # serve as "u(k-1)" for the increment constraints of the first step.
    past_u = Uc_0
    past_uf = Ucf_0
    # Masks built the same way as Zu_s / Zu_f inside the loop (slow inputs
    # use flags 1-2, fast inputs flags 0-2, and a trailing constant 1).
    past_uz_s = vertcat(Uz_0[1], Uz_0[2], 1)
    past_uz_f = vertcat(Uz_0[0], Uz_0[1], Uz_0[2], 1)
    
    # Per-step bound on the input change (rate limit times Delta_s).
    del_bns_u = del_Uc*Delta_s
    del_bns_uf = del_Ucf*Delta_s
    
    for k in range(1, pred_horzn+1):
        # Binary flags for step k, expanded into element-wise masks for the
        # slow (Zu_s) and fast (Zu_f) input vectors. The trailing 1 leaves
        # the last entry of each input vector unmasked.
        Uzk = input_bin_bnds[k-1,:]
        Zu_s = vertcat(Uzk[1], Uzk[2], 1)
        Zu_f = vertcat(Uzk[0], Uzk[1], Uzk[2], 1)
        
        # Slow input decision variables U{k-1}; masked bounds pin a
        # switched-off input to 0.
        uname = 'U' + str(k-1)
        Uk = MX.sym(uname, Nuc_s)
        w += [Uk]
        lbw +=[Zu_s*input_bnds_lo]
        ubw +=[Zu_s*input_bnds_up]
        w0 += [input_guess[k-1,:]]
        
        # Fast input decision variables U_f{k-1}, masked the same way.
        ufname = 'U_f' + str(k-1)
        Ufk = MX.sym(ufname, Nuc_f)
        w += [Ufk]
        lbw +=[Zu_f*fast_input_bnds_lo]
        ubw +=[Zu_f*fast_input_bnds_up]
        w0 += [fast_input_guess[k-1,:]]
        
        # Increment (rate) constraints between two consecutive instants:
        #   -del_bns <= u(k) - u(k-1) <= del_bns
        # Skipped only when i == 0 and k == 1 (presumably no previously
        # applied input exists then; confirm). Where a binary flag changes
        # between the two instants, the bound is widened by 1e5, which
        # effectively lifts the rate limit for that switch.
        if i != 0 or k != 1:
            del_Zu_s = fabs(Zu_s - past_uz_s)
            del_Zu_f = fabs(Zu_f - past_uz_f)
            
            g += [Uk - past_u]
            lbg += [-1*(del_bns_u + del_Zu_s*1e5)]
            ubg += [del_bns_u + del_Zu_s*1e5]
            
            g += [Ufk - past_uf]
            lbg += [-1*(del_bns_uf + del_Zu_f*1e5)]
            ubg += [del_bns_uf + del_Zu_f*1e5]
            pass
        
        # Disturbance values for step k (data, not decision variables).
        Dk = distb_bnds[k-1,:]
        
        # Integrate the slow model over one step. Xk / Xfk still hold the
        # previous step's slow and fast states here. I_ode_s is called like
        # a CasADi integrator (x0 and p in, 'xf' out).
        Ik = I_ode_s(x0 = Xk, p = vertcat(Xfk, Uk, Ufk, Uzk, Dk))
        X_int = Ik['xf']
        
        # Slow state decision variable X{k} at the end of the step.
        xname = 'X' + str(k)
        Xk = MX.sym(xname, Nx_s)
        w += [Xk]
        lbw += [state_bnds_lo]
        ubw += [state_bnds_up]
        w0 += [state_guess[k-1,:]]
        
        # Multiple-shooting continuity: the integrated state must equal X{k}.
        g += [X_int - Xk]
        lbg += [[0]*Nx_s]
        ubg += [[0]*Nx_s]
        
        # Fast state decision variable X_f{k}.
        xfname = 'X_f' + str(k)
        Xfk = MX.sym(xfname, Nx_f)
        w += [Xfk]
        lbw += [fast_state_bnds_lo]
        ubw += [fast_state_bnds_up]
        w0 += [fast_state_guess[k-1,:]]
        
        # Equality constraint (Nx_f rows) on the fast states, evaluated at
        # the new X{k}, X_f{k}. Presumably this is a quasi-steady-state
        # condition defined in Slow_model; confirm.
        g += [const_ies_s(Xk, Xfk, Uk, Ufk, Uzk, Dk)]
        lbg += [[0]*Nx_f]
        ubg += [[0]*Nx_f]
        
        # Output decision variables y{k}.
        yname = 'y' + str(k)
        Yk = MX.sym(yname, Ny_s)
        w += [Yk]
        lbw += [output_bnds_lo]
        ubw += [output_bnds_up]
        w0 += [output_guess[k-1,:]]
        
        # Output constraints: y{k} must equal the model output map.
        g += [Yk - out_ies_s(Xk, Xfk, Uk, Ufk, Uzk, Dk)]
        lbg += [[0]*Ny_s]
        ubg += [[0]*Ny_s]
        
        # Scalar zone variable y_spe{k}, limited to the per-step zone bounds;
        # the cost pulls y{k}[1] towards it.
        yname_spe = 'y_spe' + str(k)
        Yk_spe = MX.sym(yname_spe, 1)
        w += [Yk_spe]
        lbw += [yspe_bnds_lo[k-1]]
        ubw += [yspe_bnds_up[k-1]]
        w0 += [yspe_guess[k-1,:]]
        
        # Per-step sell / compensation / penalty prices and regulation
        # factor, plus rxi_as and yeb (meaning not visible here).
        pse_k = pse[k-1]
        pcm_k = pcm[k-1]
        ppn_k = ppn[k-1]
        rxi_k = rxi[k-1]
        rxi_as_k = rxi_as[k-1]
        yeb_k = yeb[k-1]
        
        # Stage cost:
        #   alpha[0]: squared deviation of y[0] from (1 + rxi_k)*yeb_k
        #   alpha[1]: squared deviation of y[1] from the zone variable y_spe
        #   alpha[2]: economic term -- pf*(fast inputs 0 and 1)
        #             + ppn_k * squared y[0] deviation - pmg*D[2]
        #             - pse_k*y[0] - pcm_k*rxi_as_k*yeb_k
        #   alpha[3], alpha[4]: squared tracking of X[0], X[1] to the
        #             long-term state set-points
        J += alpha[0] * (Yk[0] - (1 + rxi_k)*yeb_k)**2
        J += alpha[1] * (Yk[1] - Yk_spe)**2
        J += alpha[2] * (pf*(Ufk[0] + Ufk[1]) + ppn_k*(Yk[0] - (1 + rxi_k)*yeb_k)**2
                         - pmg*Dk[2] - pse_k*Yk[0] - pcm_k*rxi_as_k*yeb_k)
        J += alpha[3] * (Xk[0] - long_state_setpnts[k-1,0])**2
        J += alpha[4] * (Xk[1] - long_state_setpnts[k-1,1])**2
                
        # Current inputs/masks become "previous" for the next step's
        # increment constraints.
        past_u = Uk
        past_uf = Ufk
        past_uz_s = Zu_s
        past_uz_f = Zu_f
        
        pass
    
    # Concatenate decision variables, bounds, guesses and constraint terms
    # into column vectors for nlpsol.
    w = vertcat(*w)
    lbw = vertcat(*lbw)
    ubw = vertcat(*ubw)
    w0 = vertcat(*w0)
    g = vertcat(*g)
    lbg = vertcat(*lbg)
    ubg = vertcat(*ubg)
    
    return w, lbw, ubw, w0, g, lbg, ubg, J

# In[2] Create Slow EMPC solver
def solve_opt_s(w, lbw, ubw, w0, g, lbg, ubg, J, opts=None):
    """Create an IPOPT solver for the slow EMPC NLP and solve it once.

    Parameters
    ----------
    w, lbw, ubw, w0 :
        Decision-variable vector, its lower/upper bounds and its initial
        guess (for example as returned by ``formulate_opt_s``).
    g, lbg, ubg :
        Constraint expressions and their lower/upper bounds.
    J :
        Objective expression to minimise.
    opts : dict, optional
        Options forwarded to ``nlpsol``, for example
        ``{'ipopt.print_level': 0, 'print_time': False}``. ``None`` (the
        default) uses CasADi's default options, as before.

    Returns
    -------
    numpy.ndarray
        Flat 1-D array of the optimal decision-variable values, in the same
        order as ``w``.

    Warns
    -----
    RuntimeWarning
        If the solver statistics report ``success == False``. The last
        iterate is still returned so a closed loop keeps running, but it
        may be infeasible or non-optimal.
    """
    import warnings

    print('Creat an S-EMPC solver')
    nlp_prob = {'f': J, 'x': w, 'g': g}
    # An empty options dict is equivalent to calling nlpsol without options.
    nlp_solver = nlpsol('nlp_solver', 'ipopt', nlp_prob,
                        {} if opts is None else opts)
    # Solve NLP
    print('Solve S-EMPC')
    sol = nlp_solver(x0=w0, lbx=lbw, ubx=ubw, lbg=lbg, ubg=ubg)
    stats = nlp_solver.stats()
    print(stats)
    # Surface solver failures (infeasible, iteration limit, ...) instead of
    # silently handing back an untrustworthy iterate.
    if not stats.get('success', True):
        warnings.warn(
            f"S-EMPC solver did not succeed "
            f"(return_status={stats.get('return_status')!r}); "
            f"returning the last iterate.",
            RuntimeWarning,
            stacklevel=2,
        )
    optimalValues = sol['x'].full().ravel()
    
    return optimalValues
